
from __future__ import annotations
import math, json, time, hashlib
from dataclasses import dataclass, asdict
from typing import Dict, List, Tuple
import numpy as np
import pandas as pd

# Optional deps (graceful fallback)
try:
    import yaml  # type: ignore
    HAVE_YAML = True
except Exception:
    yaml = None  # type: ignore
    HAVE_YAML = False

try:
    from numba import njit  # noqa: F401
    NUMBA_AVAILABLE = True
except Exception:
    NUMBA_AVAILABLE = False

@dataclass
class GridSpec:
    """Lattice geometry: grid size, slit/screen positions and band layout."""

    # Number of grid columns (x direction).
    Lx: int = 509
    # Number of grid rows; row ``Ly // 2`` corresponds to y = 0.
    Ly: int = 481
    # x coordinate at which paths end (the screen column).
    screen_x: int = 508
    # y coordinates of the two path start points at x = 0.
    slit_y_up: int = 10
    slit_y_dn: int = -10
    # Extra acts per band (scaled by g), applied only where y' < 0.
    band_increments: Tuple[int, int, int, int, int] = (5, 9, 13, 17, 21)
    # Band thickness, in rows, used to turn depth below y' = 0 into a band index.
    band_height: int = 20
    # Slope of the band boundaries: y' = y - slant_alpha * x.
    slant_alpha: float = 0.0  # set to 1/19 to enable slant bands (E3)

@dataclass
class InstrumentSpec:
    """Detector-side parameters: tick scale, jitter, windows, ROI, smoothing."""

    # Base acts per grid cell; path act totals are divided by M to get ticks.
    M: int = 24
    # J in periods: per-trial phase jitter is drawn uniformly from [-J, +J].
    jitter_halfwidth: float = 0.15
    # Constructive half-window around phases 0 and 0.5. The two windows are
    # disjoint below 0.25; the audit flags overlap once 2 * epsilon >= 0.5.
    epsilon: float = 0.1
    # Region of interest is |y| <= roi_halfwidth.
    roi_halfwidth: int = 80
    # Box smoothing width (px) applied to the intensity before visibility.
    smooth_box_w: int = 9

@dataclass
class SweepSpec:
    """Parameter sweep definition: gravity levels, s_ww values, seeds, trial counts."""

    # Multipliers applied to the band increments; run g=1 first (main), then g=0,2.
    gravities: Tuple[int, ...] = (1, 0, 2)
    # s_ww thresholds: a trial counts as "passing" when its uniform draw is >= s_ww.
    s_ww_values: Tuple[float, ...] = (0.0, 0.25, 0.5, 0.75, 0.9)
    # Seeds mixed into the per-(y, g, s_ww) random generators.
    seeds: Tuple[int, ...] = (101, 202, 303)
    # Trials per (y, seed) in the profiles pass.
    K_profiles: int = 8000
    # Trials per (y, seed) in the ROI summary pass.
    K_summary: int = 50000

@dataclass
class Flags:
    """Run-time switches recorded in the manifest and audit output."""

    # NOTE(review): not read by any code visible in this file - confirm usage.
    use_numba: bool = True
    # default: Bresenham straight
    # NOTE(review): not read by any code visible in this file - confirm usage.
    use_dijkstra_4nbr: bool = False
    # When True the profiles pass covers every y; otherwise only the ROI.
    keep_profiles_full_range: bool = True
    # NOTE(review): not read by any code visible in this file - confirm usage.
    analytic_fast_mode: bool = False

def row_to_y(i: int, Ly: int) -> int:
    """Convert array row index *i* to a centred y coordinate.

    Row ``Ly // 2`` maps to y = 0; rows above it map to negative y.
    """
    origin_row = Ly // 2
    return i - origin_row

def y_to_row(y: int, Ly: int) -> int:
    """Convert a centred y coordinate to its array row index.

    y = 0 maps to row ``Ly // 2``; this is the inverse of ``row_to_y``.
    """
    origin_row = Ly // 2
    return origin_row + y

def circle_dist_vec(phi: np.ndarray, center: float) -> np.ndarray:
    """Return the wrap-around distance from *phi* to *center* on the unit circle.

    Phases are measured in periods, so the result lies in [0, 0.5].
    """
    # Shift so that `center` sits at 0.5, wrap into [0, 1), then fold about 0.5.
    wrapped = np.mod(phi - center + 0.5, 1.0)
    return np.abs(wrapped - 0.5)

def rng_for(y: int, g: int, sww: float, seed: int) -> np.random.Generator:
    """Build a reproducible generator keyed on (y, g, s_ww, seed).

    The key fields are joined with ``|`` (s_ww formatted to six decimals),
    hashed with SHA-256, and the first eight digest bytes (little-endian,
    reduced modulo 2**63 - 1) seed ``numpy.random.default_rng``.
    """
    key = "|".join((f"{y}", f"{g}", f"{sww:.6f}", f"{seed}"))
    digest = hashlib.sha256(key.encode()).digest()
    seed_word = int.from_bytes(digest[:8], byteorder='little', signed=False)
    return np.random.default_rng(seed_word % (2**63 - 1))

def build_acts_grid(grid: GridSpec, instr: InstrumentSpec, g: int) -> np.ndarray:
    """Return the (Ly, Lx) int32 grid of acts per cell for gravity level *g*.

    Every cell starts at ``instr.M``. For ``g != 0`` the cells with effective
    height ``y' = y - slant_alpha * x`` below zero are grouped into bands of
    ``grid.band_height`` rows (the deepest band absorbs everything beyond it),
    and band ``b`` gains ``grid.band_increments[b - 1] * g`` extra acts.
    """
    base = np.full((grid.Ly, grid.Lx), instr.M, dtype=np.int32)
    if g == 0:
        return base

    increments = np.array(grid.band_increments, dtype=np.int32)
    n_bands = len(increments)

    # Centred y value of each row, as a column vector for broadcasting.
    y_col = (np.arange(grid.Ly, dtype=np.int32) - (grid.Ly // 2))[:, None]
    if grid.slant_alpha == 0.0:
        y_eff = np.repeat(y_col, grid.Lx, axis=1)
    else:
        y_eff = y_col - grid.slant_alpha * np.arange(grid.Lx)[None, :]

    is_below = y_eff < 0.0
    depth = np.where(is_below, -y_eff, 0.0)
    # Bands are 1-based; 0 marks cells that receive no increment.
    band = np.floor_divide(depth, grid.band_height).astype(np.int32) + 1
    band = np.where(is_below, np.minimum(band, n_bands), 0)

    extra = np.zeros(band.shape, dtype=np.int32)
    for b, inc_b in enumerate(increments, start=1):
        extra[band == b] = inc_b * g
    return base + extra

# Legacy fallback grid height for ``bresenham_cells`` when no explicit ``Ly``
# is passed. ``path_time_ticks`` still updates it so that existing callers
# relying on the old implicit behaviour keep working.
Ly_global = 0

def bresenham_cells(x0: int, y0: int, x1: int, y1: int, Ly: int | None = None) -> List[Tuple[int, int]]:
    """Rasterise the straight segment (x0, y0) -> (x1, y1) into grid cells.

    Args:
        x0, y0: Start point (x column, centred y coordinate).
        x1, y1: End point (inclusive).
        Ly: Grid height used to convert y to a row index (``row = y + Ly // 2``).
            When omitted, the module-level ``Ly_global`` is used, which is the
            previous behaviour; pass it explicitly to avoid depending on
            whichever grid ``path_time_ticks`` saw last.

    Returns:
        ``(x, row)`` tuples in traversal order, ``max(|dx|, |dy|) + 1`` of them.
        Cells are not clipped to the grid.
    """
    row_offset = (Ly_global if Ly is None else Ly) // 2
    points: List[Tuple[int, int]] = []
    dx = x1 - x0
    dy = y1 - y0
    x_step = 1 if dx >= 0 else -1
    y_step = 1 if dy >= 0 else -1
    dx = abs(dx)
    dy = abs(dy)
    x, y = x0, y0
    if dx >= dy:
        # x is the driving axis: one cell per column.
        err = dx // 2
        for _ in range(dx + 1):
            points.append((x, y + row_offset))
            x += x_step
            err -= dy
            if err < 0:
                y += y_step
                err += dx
    else:
        # y is the driving axis: one cell per row.
        err = dy // 2
        for _ in range(dy + 1):
            points.append((x, y + row_offset))
            y += y_step
            err -= dx
            if err < 0:
                x += x_step
                err += dy
    return points

def path_time_ticks(acts: np.ndarray, x0: int, y0: int, x1: int, y1: int, M: int) -> float:
    """Return the travel time, in ticks, along the straight path through *acts*.

    The acts of every in-bounds cell on the Bresenham line from (x0, y0) to
    (x1, y1) are summed and divided by *M*. Cells that fall outside the grid
    are skipped rather than wrapped.

    Side effect: sets the module-level ``Ly_global`` to ``acts.shape[0]``
    (retained for backward compatibility).
    """
    global Ly_global
    n_rows, n_cols = acts.shape[0], acts.shape[1]
    Ly_global = n_rows
    cells = np.asarray(bresenham_cells(x0, y0, x1, y1, Ly=n_rows), dtype=np.int64)
    cols, rows = cells[:, 0], cells[:, 1]
    # Mask before indexing so negative coordinates never wrap around.
    inside = (rows >= 0) & (rows < n_rows) & (cols >= 0) & (cols < n_cols)
    total_acts = int(acts[rows[inside], cols[inside]].astype(np.int64).sum())
    return total_acts / float(M)

def precompute_deltaT_y(grid: GridSpec, instr: InstrumentSpec, acts: np.ndarray):
    """Tabulate path times from both start points to every screen position.

    Returns:
        ``(ys, t_up, t_dn, deltaT)`` where ``ys`` is the int32 array of centred
        y values (one per grid row), ``t_up`` / ``t_dn`` are the float64 tick
        counts of the paths starting at ``(0, slit_y_up)`` / ``(0, slit_y_dn)``
        and ending at ``(screen_x, y)``, and ``deltaT = t_up - t_dn``.
    """
    half = grid.Ly // 2
    ys = np.arange(-half, grid.Ly - half, dtype=np.int32)

    def ticks_from(slit_y: int) -> np.ndarray:
        # One path per screen position, all starting from the same point.
        return np.array(
            [path_time_ticks(acts, 0, slit_y, grid.screen_x, int(y), instr.M) for y in ys],
            dtype=np.float64,
        )

    t_up = ticks_from(grid.slit_y_up)
    t_dn = ticks_from(grid.slit_y_dn)
    return ys, t_up, t_dn, t_up - t_dn

def simulate_trials_for_y(phi: float, sww: float, K: int, J: float, eps: float, rng: np.random.Generator):
    """Simulate *K* trials at one screen position and count port outcomes.

    Each trial draws a uniform number (the trial "passes" when it is >= *sww*)
    and a phase jitter uniform in [-J, +J]. A passing trial whose jittered
    phase lies within *eps* of 0 only goes to port 0, within *eps* of 0.5 only
    goes to port 1. Every other trial (passing but in neither window, not
    passing, or in both windows) is assigned by a fair coin.

    Returns:
        ``(n_port0, n_port1, n_tie_c0, n_tie_c1, n_neutral, n_born)`` where
        ``n_tie_c0`` / ``n_tie_c1`` count the window-decided trials,
        ``n_neutral`` counts coin-decided trials that were in neither window
        or did not pass, and ``n_born`` counts coin-decided trials that fell
        in both windows.
    """
    def coin_split(count: int) -> Tuple[int, int]:
        # Draw only when needed so the generator state matches trial counts.
        if count <= 0:
            return 0, 0
        flips = rng.integers(0, 2, size=count)
        return int((flips == 0).sum()), int((flips == 1).sum())

    # Draw order matters for reproducibility: pass draws first, then jitter.
    passed = rng.random(K) >= sww
    jitter = rng.uniform(-J, +J, size=K)
    phase = (phi + jitter) % 1.0
    near0 = circle_dist_vec(phase, 0.0) <= eps
    near1 = circle_dist_vec(phase, 0.5) <= eps

    n_tie_c0 = int((passed & near0 & (~near1)).sum())
    n_tie_c1 = int((passed & near1 & (~near0)).sum())
    n_outside = int((passed & (~near0) & (~near1)).sum())
    n_blocked = int((~passed).sum())
    n_both = int((passed & near0 & near1).sum())

    # Coin flips are consumed in this fixed order: outside, blocked, both.
    out0, out1 = coin_split(n_outside)
    blk0, blk1 = coin_split(n_blocked)
    both0, both1 = coin_split(n_both)

    n_port0 = n_tie_c0 + out0 + blk0 + both0
    n_port1 = n_tie_c1 + out1 + blk1 + both1
    n_neutral = n_outside + n_blocked
    n_born = n_both
    assert (n_port0 + n_port1) == K, "Port counts must sum to K"
    return n_port0, n_port1, n_tie_c0, n_tie_c1, n_neutral, n_born

def box_smooth(v: np.ndarray, w: int) -> np.ndarray:
    """Smooth *v* with a centred moving average of width *w*.

    Args:
        v: 1-D array of samples.
        w: Window width in samples. ``w <= 1`` returns a copy of *v*; an even
            width is rounded up to the next odd number so the window is centred.

    Returns:
        Float array of the same length as *v*. Near the ends the average is
        taken over the samples that actually exist inside the window, so a
        constant input stays constant instead of being pulled toward zero by
        implicit zero padding.
    """
    if w <= 1:
        return v.copy()
    if w % 2 == 0:
        w += 1
    values = np.asarray(v, dtype=np.float64)
    n = values.shape[0]
    if n == 0:
        return values.copy()
    half = w // 2
    # Zero-pad explicitly and use 'valid' so the output always has length n,
    # even when the input is shorter than the window.
    window_sums = np.convolve(np.pad(values, half), np.ones(w), mode="valid")
    # Number of real (non-padded) samples under the window at each position.
    idx = np.arange(n)
    counts = np.minimum(idx + half, n - 1) - np.maximum(idx - half, 0) + 1
    return window_sums / counts

def compute_visibility(I: np.ndarray, ys: np.ndarray, roi_halfwidth: int):
    """Compute fringe visibility of *I* over the region ``|y| <= roi_halfwidth``.

    Returns:
        ``(V, Imax, Imin)`` with ``V = (Imax - Imin) / (Imax + Imin)``; ``V`` is
        0.0 when ``Imax + Imin`` is not positive.
    """
    inside = np.logical_and(ys >= -roi_halfwidth, ys <= roi_halfwidth)
    window = I[inside]
    hi = float(np.max(window))
    lo = float(np.min(window))
    total = hi + lo
    if total > 0:
        return (hi - lo) / total, hi, lo
    return 0.0, hi, lo

def run_profiles_and_summary(grid: GridSpec, instr: InstrumentSpec, sweep: SweepSpec, flags: Flags):
    """Run the full (g, s_ww, seed) sweep and assemble profiles, summary and audit.

    For each gravity level the acts grid and per-y path-time difference are
    built once; the fractional part of that difference is the phase fed to
    ``simulate_trials_for_y``. Two passes follow for every ``s_ww`` value:

    * profiles: ``K_profiles`` trials per (y, seed), over all y or only the
      ROI depending on ``flags.keep_profiles_full_range``;
    * summary: ``K_summary`` trials per (y, seed) over the ROI, reduced to a
      smoothed intensity, its visibility ``V`` and a tie-rate integral.

    ``D`` is then derived per gravity level as ``1 - tie_int / tie_int(s_ww=0)``
    (NaN when there is not exactly one positive ``s_ww == 0`` reference row).

    Only ``flags.keep_profiles_full_range`` is read here; the other flags are
    just recorded in the audit.

    Returns:
        ``(profiles_df, summary_df, audit)``. With an empty sweep (no gravity
        levels or no ``s_ww`` values) both frames are empty and the per-g audit
        sections stay empty.
    """
    t0 = time.time()
    profiles_rows: List[Dict] = []
    summary_rows: List[Dict] = []
    audit = {
        "curve_lint": {},
        "no_skip": True,
        "pf_born_ties_only": True,
        "born_invocations": 0,
        "englert_check": {},
        "params": {},
        "notes": {"windows_overlap": (2*instr.epsilon >= 0.5)}
    }

    half = grid.Ly // 2
    all_y = np.arange(-half, grid.Ly - half, dtype=np.int32)
    roi_mask = (all_y >= -instr.roi_halfwidth) & (all_y <= instr.roi_halfwidth)
    roi_y = all_y[roi_mask]
    y_iter = all_y if flags.keep_profiles_full_range else roi_y

    # Loop invariants, computed once instead of per row / per trial batch.
    n_seeds = len(sweep.seeds)
    seeds_label = ",".join(map(str, sweep.seeds))
    count_keys = ("port0", "port1", "tie_c0", "tie_c1", "neutral", "born")
    center_hits = np.flatnonzero(roi_y == 0)
    center_idx = int(center_hits[0]) if center_hits.size else None
    nan = float('nan')

    for g in sweep.gravities:
        acts = build_acts_grid(grid, instr, g)
        ys, _t_up, _t_dn, deltaT = precompute_deltaT_y(grid, instr, acts)
        assert np.all(ys == all_y), "y grid mismatch"
        phi_y = (deltaT % 1.0)
        # Row index of y is y + Ly // 2, so gather the phases for each pass.
        phi_profiles = phi_y[y_iter.astype(np.intp) + half]
        phi_roi = phi_y[roi_y.astype(np.intp) + half]

        # Profiles
        Kp = sweep.K_profiles
        for sww in sweep.s_ww_values:
            acc = {k: np.zeros(len(y_iter), dtype=np.int64) for k in count_keys}
            for seed in sweep.seeds:
                for iy, y in enumerate(y_iter):
                    counts = simulate_trials_for_y(
                        phi=float(phi_profiles[iy]),
                        sww=float(sww), K=Kp, J=instr.jitter_halfwidth, eps=instr.epsilon,
                        rng=rng_for(int(y), g, sww, seed)
                    )
                    for key, value in zip(count_keys, counts):
                        acc[key][iy] += value
            Ktot = Kp * n_seeds
            p0 = acc["port0"] / Ktot
            p1 = acc["port1"] / Ktot
            bias = p0 - p1
            tie_rate = (acc["tie_c0"] + acc["tie_c1"]) / Ktot
            neutral = acc["neutral"] / Ktot
            profiles_rows.extend(
                {
                    "y": int(y), "g": int(g), "s_ww": float(sww),
                    "p0": float(p0[iy]), "p1": float(p1[iy]), "bias": float(bias[iy]),
                    "tie_rate": float(tie_rate[iy]), "neutral": float(neutral[iy])
                }
                for iy, y in enumerate(y_iter)
            )
            audit["born_invocations"] += int(acc["born"].sum())

        # Summary (ROI)
        # NOTE(review): "born" counts from this pass are not added to
        # audit["born_invocations"]; confirm that is intended.
        Ksum = sweep.K_summary
        Ktot = float(Ksum)
        for sww in sweep.s_ww_values:
            p0_map = np.zeros(len(roi_y), dtype=np.float64)
            p1_map = np.zeros(len(roi_y), dtype=np.float64)
            tie_map = np.zeros(len(roi_y), dtype=np.float64)
            neutral_map = np.zeros(len(roi_y), dtype=np.float64)
            for seed in sweep.seeds:
                for iy, y in enumerate(roi_y):
                    n0, n1, nc0, nc1, nneu, _nborn = simulate_trials_for_y(
                        phi=float(phi_roi[iy]),
                        sww=float(sww), K=Ksum, J=instr.jitter_halfwidth, eps=instr.epsilon,
                        rng=rng_for(int(y), g, sww, seed)
                    )
                    p0_map[iy] += n0 / Ktot
                    p1_map[iy] += n1 / Ktot
                    tie_map[iy] += (nc0 + nc1) / Ktot
                    neutral_map[iy] += nneu / Ktot
            p0_map /= n_seeds
            p1_map /= n_seeds
            tie_map /= n_seeds
            neutral_map /= n_seeds
            I = 0.5 + 0.5 * (p0_map - p1_map)
            I_s = box_smooth(I, instr.smooth_box_w)
            V, Imax, Imin = compute_visibility(I_s, roi_y, instr.roi_halfwidth)
            tie_int = float(tie_map.sum())
            audit.setdefault("_tie_integrals", {})[f"{g},{sww}"] = tie_int
            summary_rows.append({
                "g": g, "s_ww": sww, "V": V, "Imax": Imax, "Imin": Imin,
                "tie_rate_center": float(tie_map[center_idx]) if center_idx is not None else nan,
                "neutral_center": float(neutral_map[center_idx]) if center_idx is not None else nan,
                "K": int(Ksum * n_seeds), "J": instr.jitter_halfwidth,
                "epsilon": instr.epsilon, "M": instr.M, "seeds": seeds_label,
                "_tie_int": tie_int
            })

    profiles_df = pd.DataFrame(profiles_rows)
    summary_df = pd.DataFrame(summary_rows)

    # An empty sweep yields a column-less frame; skip the per-g post-processing
    # instead of failing on the missing "g" column.
    if summary_rows:
        g_col = summary_df["g"]
        g_values = g_col.unique()
        is_reference = (summary_df["s_ww"] == 0.0).to_numpy()
        tie_ints = summary_df["_tie_int"].to_numpy(dtype=np.float64)

        # D relative to the s_ww == 0 row of the same gravity level.
        d_col = np.full(len(summary_df), np.nan, dtype=np.float64)
        for g_val in g_values:
            in_g = (g_col == g_val).to_numpy()
            ref = tie_ints[in_g & is_reference]
            if len(ref) == 1 and ref[0] > 0:
                d_col[in_g] = 1.0 - (tie_ints[in_g] / float(ref[0]))
        summary_df["D"] = d_col

        tol = 5e-3
        for g_val in g_values:
            sub = summary_df[g_col == g_val]

            # Monotonicity audit: V may not rise by more than tol as s_ww grows.
            ordered = sub.sort_values("s_ww")
            Vs = ordered["V"].to_numpy()
            monotone = not bool(np.any(np.diff(Vs) > tol))
            audit["curve_lint"][str(g_val)] = {
                "monotone_nonincreasing_V_vs_sww": bool(monotone),
                "s_ww": ordered["s_ww"].tolist(),
                "V": ordered["V"].round(6).tolist(),
                "D": ordered["D"].round(6).tolist()
            }

            # Englert check: largest value of V^2 + D^2 - 1 for this g.
            excess = [(float(v)**2 + float(d)**2) - 1.0 for v, d in zip(sub["V"], sub["D"])]
            audit["englert_check"][str(g_val)] = {"max_excess": float(max(excess))}

        summary_df = summary_df.drop(columns=["_tie_int"])

    audit["pf_born_ties_only"] = (audit["born_invocations"] == 0)
    audit["params"] = {
        "grid": asdict(grid), "instrument": asdict(instr),
        "sweep": asdict(sweep), "flags": asdict(flags)
    }
    audit["runtime_sec"] = float(time.time() - t0)

    return profiles_df, summary_df, audit

def write_manifest_yaml(path: str, grid: GridSpec, instr: InstrumentSpec, sweep: SweepSpec, flags: Flags):
    """Write the run configuration to *path* as YAML, or as JSON without PyYAML.

    When PyYAML is unavailable the manifest is written as JSON next to the
    requested location: the file extension of *path* (if any) is replaced by
    ``.json``. Only the final path component is considered for the extension,
    so dots in directory names are left alone.
    """
    import os  # local import: only needed for the JSON fallback path

    manifest = {
        "grid": asdict(grid),
        "instrument": asdict(instr),
        "sweep": asdict(sweep),
        "flags": asdict(flags),
    }
    if HAVE_YAML:
        with open(path, "w", encoding="utf-8") as f:
            yaml.safe_dump(manifest, f, sort_keys=False)  # type: ignore
        return
    # Fallback: write JSON if YAML not available
    root, _ext = os.path.splitext(path)
    with open(root + ".json", "w", encoding="utf-8") as f:
        json.dump(manifest, f, indent=2)

def write_audit_json(path: str, audit: Dict):
    """Serialise the *audit* dictionary to *path* as JSON with 2-space indent.

    Keys are written in insertion order; an existing file is overwritten.
    """
    with open(path, mode="w") as handle:
        json.dump(audit, handle, sort_keys=False, indent=2)
